{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AgentOptimizer: An Agentic Way to Train Your LLM Agent\n",
    "\n",
    "AutoGen offers conversable agents powered by LLM, tool, or human, which can be used to perform tasks collectively via automated chat. This framework allows tool use and human participation through multi-agent conversation.\n",
    "Please find documentation about this feature [here](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat).\n",
    "\n",
    "In traditional ML pipeline, we train a model by updating its parameter according to the loss on the training set, while in the era of LLM agents, how should we train an agent? Here, we take an initial step towards the agent training. Inspired by the [function calling](https://platform.openai.com/docs/guides/function-calling) capabilities provided by OpenAI, we draw an analogy between model parameters and agent functions/skills, and update agent’s functions/skills based on its historical performance on the training set. As an agentic way of training an agent, our approach help enhance the agents’ abilities without requiring access to the LLMs parameters.\n",
    "\n",
    "In this notebook, we introduce a new class, ‘AgentOptimizer’, which is able to improve the function list of one Assistant-UserProxy pair according to the historical conversation histories. This feature would support agents in improving their ability to solve problems of the same type as previous tasks.\n",
    "Specifically, given a set of training data, AgentOptimizer would iteratively prompt the LLM to optimize the existing function list of the AssistantAgent and UserProxyAgent with code implementation if necessary.\n",
    "In the example scenario, we test the proposed AgentOptimizer in solving problems from the [MATH dataset](https://github.com/hendrycks/math). \n",
    "\n",
    "Paper is coming soon!\n",
    "\n",
    "![AgentEval](../website/blog/2023-12-23-AgentOptimizer/img/agentoptimizer.png)\n",
    "\n",
    "Authors:\n",
    "- [Shaokun Zhang](https://github.com/skzhang1), Ph.D. student at the The Pennsylvania State University\n",
    "- [Jieyu Zhang](https://jieyuz2.github.io), Ph.D. student at the University of Washington"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from typing import Any, Callable, Dict, List, Optional, Tuple, Union\n",
    "\n",
    "from openai import AzureOpenAI, BadRequestError\n",
    "\n",
    "import autogen\n",
    "from autogen.agentchat import Agent\n",
    "from autogen.agentchat.contrib.math_user_proxy_agent import MathUserProxyAgent\n",
    "from autogen.code_utils import execute_code, extract_code\n",
    "from autogen.math_utils import get_answer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AgentOptimizer\n",
    "\n",
    "AgentOptimizer is a class that is designed to improve the agents through optimizing its function call. It contains two core methods:\n",
    "\n",
    "1. `step()`: `step()` has three inputs: previous conversation history (history), the statistical information of solving previous problems (statistic), and the signature of current functions (func_signature). The output is a series of actions to manipulate the current functions.\n",
    "\n",
    "2. `update_function_call()`: This method updates the functions registered in the agents according to the actions from `step()`.   "
   ]
  },
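  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following standalone sketch (not part of AutoGen) illustrates how the actions returned by `step()` reshape a list of function signatures. It assumes each action is a dict with an `action_name` key plus the five signature fields, which is the shape that `AgentOptimizer._format_actions` produces in the next cell; the helper `apply_actions` and the example signature values are hypothetical.\n",
    "\n",
    "Both `add_function` and `revise_function` replace any existing function with the same name, while `remove_function` only drops it. The helper returns a new list instead of mutating its input, mirroring `step()`, whose prompt shows both the original list (List A) and the modified list (List B)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, List\n",
    "\n",
    "# The five elements of a function signature, as listed in OPT_PROMPT.\n",
    "SIGNATURE_KEYS = (\"name\", \"description\", \"arguments\", \"packages\", \"code\")\n",
    "\n",
    "\n",
    "def apply_actions(signatures: List[Dict[str, Any]], actions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:\n",
    "    \"\"\"Hypothetical helper: apply add/revise/remove actions to a list of function signatures.\"\"\"\n",
    "    result = list(signatures)  # copy the list so the input is not modified\n",
    "    for action in actions:\n",
    "        # Every action first drops any existing function with the same name ...\n",
    "        result = [sig for sig in result if sig[\"name\"] != action[\"name\"]]\n",
    "        # ... then add_function / revise_function append the new full signature.\n",
    "        if action[\"action_name\"] != \"remove_function\":\n",
    "            result.append({key: action.get(key) for key in SIGNATURE_KEYS})\n",
    "    return result\n",
    "\n",
    "\n",
    "# Illustrative signature (placeholder values, not produced by an LLM).\n",
    "signatures = [\n",
    "    {\n",
    "        \"name\": \"evaluate_expression\",\n",
    "        \"description\": \"Evaluate a math expression given as a string.\",\n",
    "        \"arguments\": {\"expression\": {\"type\": \"string\", \"description\": \"The expression to evaluate.\"}},\n",
    "        \"packages\": \"sympy\",\n",
    "        \"code\": \"# implementation omitted in this sketch\",\n",
    "    }\n",
    "]\n",
    "\n",
    "revise = {**signatures[0], \"action_name\": \"revise_function\", \"description\": \"Evaluate and simplify a math expression.\"}\n",
    "revised = apply_actions(signatures, [revise])\n",
    "print(revised[0][\"description\"])  # Evaluate and simplify a math expression.\n",
    "\n",
    "removed = apply_actions(signatures, [{\"action_name\": \"remove_function\", \"name\": \"evaluate_expression\"}])\n",
    "print(removed)  # []"
   ]
  },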
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AgentOptimizer:\n",
    "    OPT_PROMPT = \"\"\"You are a function optimizer. Your task is to maintain a list of functions for the assistant according to the existing function list and conversation history that happens between the assistant and the user.\n",
    "You can perform one of the following four actions to manipulate the function list using the functions you have:\n",
    "1. Revise one existing function (using revise_function).\n",
    "2. Remove one existing function (using remove_function).\n",
    "3. Add one new function (using add_function).\n",
    "4. Directly return \"TERMINATE\" to me if no more actions are needed for the current function list.\n",
    "\n",
    "Below are the principles that you need to follow for taking these four actions.\n",
    "(1) Revise one existing function:\n",
    "1. Pay more attention to the failed tasks and corresponding error information, and optimize the function used in these tasks according to the conversation history if needed.\n",
    "2. A failed function call can occur due to incorrect input arguments (missing arguments) or an incorrect function code implementation. You should focus more on the function code implementation and make it easy to get success function call.\n",
    "3. Do not revise the function that you think works well and plays a critical role in solving the problems according to the conversation history. Only making revisions if needed.\n",
    "4. Sometimes, a NameError may occur. To fix this error, you can either revise the name of the function in the code implementation or revise the name of the function call to make these two names consistent.\n",
    "(2) Remove one existing function:\n",
    "1. Only remove the function that you think is not needed anymore in future tasks.\n",
    "(3) Add one new function:\n",
    "1. The added new function should solve a higher-level question that encompasses the original query and extend the code's functionality to make it more versatile and widely applicable.\n",
    "2. The added new function should solve queries of the same type, based on common reasoning steps without mentioning specific object names or entity terms.\n",
    "3. Name the function and write the description concerning both the core reasoning pattern and data organization format, without referencing specific objects. The name of the function MUST be the same with the function name in the code you generated.\n",
    "4. Replace specific strings or variable names with general variables to enhance the tool's applicability to various queries. All names used inside the function should be passed in as arguments.\n",
    "(4) Directly return \"TERMINATE\":\n",
    "If you think there is no need to perform any other actions for the current function list since the current list is optimal more actions will harm the performance in future tasks. Please directly reply to me with \"TERMINATE\".\n",
    "\n",
    "One function signature includes the following five elements:\n",
    "1. Function name\n",
    "2. Function description\n",
    "3. JSON schema of arguments encoded as a string\n",
    "4. A list of package names imported by the function packages\n",
    "5. The code implementation\n",
    "\n",
    "Below are the signatures of the current functions:\n",
    "List A: {signiture}.\n",
    "The success rate (performance) with this function list is {success_rate}.\n",
    "The following list are the function signatures that you have after taking {actions_num} actions in our previous conversations:\n",
    "List B: {after_signiture}.\n",
    "Here are {conversation_num} conversation histories of solving {conversation_num} tasks.\n",
    "History:\n",
    "{history}\n",
    "The following table shows the statistical information for solving each task in each conversation and indicates whether each task was successfully solved.\n",
    "1 represents correct. 0 represents wrong.\n",
    "statistic:\n",
    "{statistic}\n",
    "\n",
    "According to the information I provide, please take one of four actions to manipulate list B using the functions you know.\n",
    "    \"\"\"\n",
    "\n",
    "    ADD_FUNC = {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"add_function\",\n",
    "            \"description\": \"Add a function in the context of the conversation. Necessary Python packages must be declared. The name of the function MUST be the same with the function name in the code you generated.\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"name\": {\"type\": \"string\", \"description\": \"The name of the function in the code implementation.\"},\n",
    "                    \"description\": {\"type\": \"string\", \"description\": \"A short description of the function.\"},\n",
    "                    \"arguments\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { \"url\": { \"type\": \"string\", \"description\": \"The URL\", }}. Please avoid the error \\'array schema missing items\\' when using array type.',\n",
    "                    },\n",
    "                    \"packages\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.\",\n",
    "                    },\n",
    "                    \"code\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The implementation in Python. Do not include the function declaration.\",\n",
    "                    },\n",
    "                },\n",
    "                \"required\": [\"name\", \"description\", \"arguments\", \"packages\", \"code\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "\n",
    "    REVISE_FUNC = {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"revise_function\",\n",
    "            \"description\": \"Revise a function in the context of the conversation. Necessary Python packages must be declared. The name of the function MUST be the same with the function name in the code you generated.\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"name\": {\"type\": \"string\", \"description\": \"The name of the function in the code implementation.\"},\n",
    "                    \"description\": {\"type\": \"string\", \"description\": \"A short description of the function.\"},\n",
    "                    \"arguments\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { \"url\": { \"type\": \"string\", \"description\": \"The URL\", }}. Please avoid the error \\'array schema missing items\\' when using array type.',\n",
    "                    },\n",
    "                    \"packages\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.\",\n",
    "                    },\n",
    "                    \"code\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The implementation in Python. Do not include the function declaration.\",\n",
    "                    },\n",
    "                },\n",
    "                \"required\": [\"name\", \"description\", \"arguments\", \"packages\", \"code\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "\n",
    "    REMOVE_FUNC = {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"remove_function\",\n",
    "            \"description\": \"Remove one function in the context of the conversation. Once remove one function, the assistant will not use this function in future conversation.\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"name\": {\"type\": \"string\", \"description\": \"The name of the function in the code implementation.\"}\n",
    "                },\n",
    "                \"required\": [\"name\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "\n",
    "    def __init__(self, OAI_config, action_num=3, each_action_max_trials=10):\n",
    "        self._action_num = action_num\n",
    "        self._each_action_max_trials = each_action_max_trials\n",
    "        os.environ[\"AZURE_OPENAI_API_KEY\"] = OAI_config[\"AZURE_OPENAI_API_KEY\"]  # TODO: input key into client\n",
    "        self._client = AzureOpenAI(\n",
    "            api_version=OAI_config[\"api_version\"],\n",
    "            azure_endpoint=OAI_config[\"azure_endpoint\"],\n",
    "        )\n",
    "        self.model = \"gpt-4-1106-preview\"\n",
    "\n",
    "    def _val_json(self, actions):\n",
    "        if actions is None:\n",
    "            return True\n",
    "        else:\n",
    "            for action in actions:\n",
    "                function_args = action.function.arguments\n",
    "                try:\n",
    "                    function_args = json.loads(function_args.strip('\"'))\n",
    "                    if \"arguments\" in function_args.keys():\n",
    "                        json.loads(function_args.get(\"arguments\").strip('\"'))\n",
    "                except Exception as e:\n",
    "                    print(\"JSON is invalid:\", e)\n",
    "                    return False\n",
    "        return True\n",
    "\n",
    "    def _val_remove(self, actions, after_signiture):\n",
    "        if actions is None:\n",
    "            return True\n",
    "        else:\n",
    "            for action in actions:\n",
    "                action_name = action.function.name\n",
    "                if action_name == \"remove_function\":\n",
    "                    function_args = json.loads(action.function.arguments.strip('\"'))\n",
    "                    if function_args.get(\"name\") not in [item[\"name\"] for item in after_signiture]:\n",
    "                        print(\"The function you want to remove does not exist.\")\n",
    "                        return False\n",
    "            return True\n",
    "\n",
    "    def _val_syntax(self, actions):\n",
    "        if actions is None:\n",
    "            return True\n",
    "        else:\n",
    "            for action in actions:\n",
    "                if action.function.name != \"remove_function\":\n",
    "                    function_args = json.loads(action.function.arguments.strip('\"'))\n",
    "                    code = function_args.get(\"code\")\n",
    "                    try:\n",
    "                        compile(code, \"<string>\", \"exec\")\n",
    "                        print(\"successfully compiled\")\n",
    "                    except SyntaxError as e:\n",
    "                        print(\"Syntax is invalid:\", e)\n",
    "                        return False\n",
    "            return True\n",
    "\n",
    "    def _format_actions(self, actions):\n",
    "        ans = []\n",
    "        for action in actions:\n",
    "            func = json.loads(action.function.arguments.strip('\"'))\n",
    "            func[\"action_name\"] = action.function.name\n",
    "\n",
    "            if func.get(\"action_name\") == \"remove_function\":\n",
    "                item = {\n",
    "                    \"action_name\": func.get(\"action_name\"),\n",
    "                    \"name\": func.get(\"name\"),\n",
    "                }\n",
    "            else:\n",
    "                item = {\n",
    "                    \"action_name\": func.get(\"action_name\"),\n",
    "                    \"name\": func.get(\"name\"),\n",
    "                    \"description\": func.get(\"description\"),\n",
    "                    \"arguments\": json.loads(func.get(\"arguments\").strip('\"')),\n",
    "                    \"packages\": func.get(\"packages\"),\n",
    "                    \"code\": func.get(\"code\"),\n",
    "                }\n",
    "            ans.append(item)\n",
    "        return ans\n",
    "\n",
    "    def _get_success_rate(self, statistic):\n",
    "        sum = 0\n",
    "        for key, value in statistic.items():\n",
    "            if \"is_correct\" not in value.keys():\n",
    "                statistic[key][\"is_correct\"] = 0\n",
    "        for key, value in statistic.items():\n",
    "            sum += value[\"is_correct\"]\n",
    "        if len(statistic.keys()) != 0:\n",
    "            success_rate = sum / len(statistic.keys())\n",
    "        else:\n",
    "            success_rate = None\n",
    "        return success_rate, statistic\n",
    "\n",
    "    def _modify_function_signiture(self, cur_functions, action_json):\n",
    "        for action in action_json:\n",
    "            action_name = action.get(\"action_name\")\n",
    "            if action_name != \"remove_function\":\n",
    "                cur_functions = [item for item in cur_functions if item[\"name\"] != action.get(\"name\")]\n",
    "                cur_functions.append(\n",
    "                    {\n",
    "                        \"name\": action.get(\"name\"),\n",
    "                        \"description\": action.get(\"description\"),\n",
    "                        \"arguments\": action.get(\"arguments\"),\n",
    "                        \"packages\": action.get(\"packages\"),\n",
    "                        \"code\": action.get(\"code\"),\n",
    "                    }\n",
    "                )\n",
    "            else:\n",
    "                cur_functions = [item for item in cur_functions if item[\"name\"] != action.get(\"name\")]\n",
    "        return cur_functions\n",
    "\n",
    "    def update_function_call(self, action, mathproxyagent, assistant):\n",
    "        def execute_func(name, packages, code, **args):\n",
    "            pip_install = (\n",
    "                f\"\"\"print(\"Installing package: {packages}\")\\nsubprocess.run([\"pip\", \"-qq\", \"install\", \"{packages}\"])\"\"\"\n",
    "                if packages\n",
    "                else \"\"\n",
    "            )\n",
    "            str = f\"\"\"\n",
    "import subprocess\n",
    "{pip_install}\n",
    "print(\"Result of {name} function execution:\")\n",
    "{code}\n",
    "args={args}\n",
    "result={name}(**args)\n",
    "if result is not None: print(result)\n",
    "\"\"\"\n",
    "            print(f\"execute_code:\\n{str}\")\n",
    "            result = execute_code(str)\n",
    "            if result[0] != 0:\n",
    "                raise Exception(\"Error in executing function:\" + result[1])\n",
    "            print(f\"Result: {result[1]}\")\n",
    "            return result[1]\n",
    "\n",
    "        name, description, arguments, packages, code, action_name = (\n",
    "            action.get(\"name\"),\n",
    "            action.get(\"description\"),\n",
    "            action.get(\"arguments\"),\n",
    "            action.get(\"packages\"),\n",
    "            action.get(\"code\"),\n",
    "            action.get(\"action_name\"),\n",
    "        )\n",
    "\n",
    "        if name in mathproxyagent._function_map.keys():\n",
    "            del mathproxyagent._function_map[name]\n",
    "        if action_name != \"remove_function\":\n",
    "            function_config = {\n",
    "                \"name\": name,\n",
    "                \"description\": description,\n",
    "                \"parameters\": {\"type\": \"object\", \"properties\": arguments},\n",
    "            }\n",
    "            mathproxyagent.register_function(\n",
    "                function_map={name: lambda **args: execute_func(name, packages, code, **args)}\n",
    "            )\n",
    "            assistant.update_function_signature(function_config, is_remove=False)\n",
    "        else:\n",
    "            assistant.update_function_signature(name, is_remove=True)\n",
    "\n",
    "    def step(self, history, statistic, func_signiture):\n",
    "        action_return = []\n",
    "        origin_signiture = func_signiture\n",
    "        modified_signiture = origin_signiture\n",
    "\n",
    "        success_rate, statistic = self._get_success_rate(statistic)  # TODO: make statistic feasible outside of the loop\n",
    "        for action_index in range(self._action_num):\n",
    "            prompt = self.OPT_PROMPT.format(\n",
    "                conversation_num=len(history),\n",
    "                statistic={\"is_correct\": statistic},\n",
    "                signiture=origin_signiture,\n",
    "                history=history,\n",
    "                success_rate=success_rate,\n",
    "                actions_num=action_index,\n",
    "                after_signiture=modified_signiture,\n",
    "            )\n",
    "            messages = [{\"role\": \"user\", \"content\": prompt}]\n",
    "            for _ in range(self._each_action_max_trials):\n",
    "                response = self._client.chat.completions.create(\n",
    "                    model=self.model,\n",
    "                    messages=messages,\n",
    "                    tools=[self.ADD_FUNC, self.REVISE_FUNC, self.REMOVE_FUNC],\n",
    "                    tool_choice=\"auto\",\n",
    "                )\n",
    "                actions = response.choices[0].message.tool_calls\n",
    "                if (\n",
    "                    self._val_json(actions)\n",
    "                    and self._val_syntax(actions)\n",
    "                    and self._val_remove(actions, modified_signiture)\n",
    "                ):\n",
    "                    break\n",
    "            if actions is not None:\n",
    "                action_result = self._format_actions(actions)\n",
    "                action_return = action_return + action_result\n",
    "                modified_signiture = self._modify_function_signiture(modified_signiture, action_result)\n",
    "        return action_return, modified_signiture"
   ]
  },
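  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before accepting the tool calls returned by the LLM, `step()` retries (up to `each_action_max_trials` times) until the calls pass three checks: every payload is valid JSON (including the nested `arguments` schema string), every generated code snippet compiles, and every function to be removed exists. The sketch below reproduces the first two checks for a single tool call; `validate_tool_arguments` and the example payloads are hypothetical. Like `_val_syntax`, it only compiles the code and never executes it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def validate_tool_arguments(tool_name: str, raw_arguments: str) -> bool:\n",
    "    \"\"\"Hypothetical helper mirroring AgentOptimizer's JSON and syntax checks for one tool call.\"\"\"\n",
    "    try:\n",
    "        # The tool-call payload is a JSON object encoded as a string ...\n",
    "        args = json.loads(raw_arguments.strip('\"'))\n",
    "        if not isinstance(args, dict):\n",
    "            raise ValueError(\"payload is not a JSON object\")\n",
    "        # ... whose \"arguments\" field is itself a JSON schema encoded as a string.\n",
    "        if \"arguments\" in args:\n",
    "            json.loads(args[\"arguments\"].strip('\"'))\n",
    "    except (ValueError, AttributeError) as e:  # JSONDecodeError is a subclass of ValueError\n",
    "        print(\"JSON is invalid:\", e)\n",
    "        return False\n",
    "    if tool_name != \"remove_function\":\n",
    "        try:\n",
    "            # Only check that the generated code parses; nothing is executed here.\n",
    "            compile(args.get(\"code\") or \"\", \"<string>\", \"exec\")\n",
    "        except SyntaxError as e:\n",
    "            print(\"Syntax is invalid:\", e)\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "\n",
    "good = json.dumps(\n",
    "    {\n",
    "        \"name\": \"add_numbers\",\n",
    "        \"description\": \"Add two integers.\",\n",
    "        \"arguments\": json.dumps({\"a\": {\"type\": \"integer\"}, \"b\": {\"type\": \"integer\"}}),\n",
    "        \"packages\": \"\",\n",
    "        \"code\": \"def add_numbers(a, b):\\n    return a + b\",\n",
    "    }\n",
    ")\n",
    "bad = json.dumps({\"name\": \"broken\", \"arguments\": \"{}\", \"code\": \"def broken(:\\n    pass\"})\n",
    "print(validate_tool_arguments(\"add_function\", good))  # True\n",
    "print(validate_tool_arguments(\"add_function\", bad))  # prints the SyntaxError, then False"
   ]
  },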
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MathUserProxy with function_call\n",
    "\n",
    "This agent is a customozied MathUserProxy inherits from its [partent class](https://github.com/microsoft/autogen/blob/main/autogen/agentchat/contrib/math_user_proxy_agent.py.) \n",
    "\n",
    "It supports using both function_call and python to solve math problems."
   ]
  },
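  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A conversation with this agent ends when a message contains `TERMINATE`, or when it contains a non-empty `\\boxed{}` answer and no code block (see `is_termination_msg_mathchat` below, which uses AutoGen's `extract_code` and `get_answer`). The standalone sketch below imitates this rule under simplifying assumptions: the hypothetical `extract_boxed` helper does plain brace matching in place of `get_answer`, and only Python code fences are detected (the original also checks for `wolfram` blocks)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "\n",
    "BOXED = \"\\\\boxed{\"\n",
    "\n",
    "\n",
    "def extract_boxed(text: str) -> Optional[str]:\n",
    "    \"\"\"Hypothetical, simplified stand-in for get_answer: content of the last \\\\boxed{...}, or None.\"\"\"\n",
    "    start = text.rfind(BOXED)\n",
    "    if start == -1:\n",
    "        return None\n",
    "    begin = start + len(BOXED)\n",
    "    depth, i = 1, begin\n",
    "    # Walk forward, tracking nested braces until the opening brace is closed.\n",
    "    while i < len(text) and depth > 0:\n",
    "        if text[i] == \"{\":\n",
    "            depth += 1\n",
    "        elif text[i] == \"}\":\n",
    "            depth -= 1\n",
    "        i += 1\n",
    "    return text[begin : i - 1] if depth == 0 else None\n",
    "\n",
    "\n",
    "def looks_terminal(message: str) -> bool:\n",
    "    \"\"\"Stop on TERMINATE, or on a non-empty boxed answer in a message without a Python code block.\"\"\"\n",
    "    if \"TERMINATE\" in message:\n",
    "        return True\n",
    "    has_code = \"```python\" in message\n",
    "    answer = extract_boxed(message)\n",
    "    return not has_code and answer is not None and answer != \"\"\n",
    "\n",
    "\n",
    "print(looks_terminal(\"So the answer is \\\\boxed{\\\\frac{1}{2}}.\"))  # True\n",
    "print(looks_terminal(\"```python\\nprint(1 / 2)\\n```\"))  # False"
   ]
  },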
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_termination_msg_mathchat(message):\n",
    "    if isinstance(message, dict):\n",
    "        message = message.get(\"content\")\n",
    "        if message is None:\n",
    "            return False\n",
    "    cb = extract_code(message)\n",
    "    contain_code = False\n",
    "    for c in cb:\n",
    "        if c[0] == \"python\" or c[0] == \"wolfram\":\n",
    "            contain_code = True\n",
    "            break\n",
    "    if message.rstrip().find(\"TERMINATE\") >= 0:\n",
    "        return True\n",
    "\n",
    "    return not contain_code and get_answer(message) is not None and get_answer(message) != \"\"\n",
    "\n",
    "\n",
    "class MathUserProxyAgent(MathUserProxyAgent):\n",
    "    MAX_CONSECUTIVE_AUTO_REPLY = 15\n",
    "    DEFAULT_REPLY = \"Continue. Please keep solving the problem until you need to query. (If you get to the answer, put it in \\\\boxed{}.)\"\n",
    "    PROMPTS = \"\"\"Let's solve a math problem.\n",
    "Query requirements:\n",
    "You should always use the 'print' function for the output and use fractions/radical forms instead of decimals.\n",
    "You can use packages like sympy to help you.\n",
    "You must follow the formats below to write your code:\n",
    "```python\n",
    "# your code\n",
    "```\n",
    "If some packages are missing, you could also suggest a code to install the corresponding package.\n",
    "\n",
    "Please follow this process:\n",
    "1. Solve the problem step by step (do not over-divide the steps).\n",
    "2. Take out any queries that can be asked through Python code (for example, any calculations or equations that can be calculated) and functions you know in the context of this conversation.\n",
    "\n",
    "Please\n",
    "(1) do not mix suggested Python codes and function calls in one step.\n",
    "(2) You MUST remember that you don’t have a function named \"python\" available.\n",
    "\n",
    "You must follow the formats below to write your Python code:\n",
    "```python\n",
    "# your code\n",
    "```\n",
    "\n",
    "3. Wait for me to give the results or wait for the executed results of the function call.\n",
    "4. Continue if you think the result is correct. If the result is invalid or unexpected, please correct your query or reasoning.\n",
    "\n",
    "After all the queries are run and you get the answer, put the answer in \\\\boxed{}.\n",
    "\n",
    "Problem:\n",
    "\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        name: Optional[str] = \"MathChatAgent\",\n",
    "        is_termination_msg: Optional[Callable[[Dict], bool]] = is_termination_msg_mathchat,\n",
    "        human_input_mode: Optional[str] = \"NEVER\",\n",
    "        default_auto_reply: Optional[Union[str, Dict, None]] = DEFAULT_REPLY,\n",
    "        max_invalid_q_per_step=3,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        super().__init__(\n",
    "            name=name,\n",
    "            is_termination_msg=is_termination_msg,\n",
    "            human_input_mode=human_input_mode,\n",
    "            default_auto_reply=default_auto_reply,\n",
    "            max_invalid_q_per_step=max_invalid_q_per_step,\n",
    "            **kwargs,\n",
    "        )\n",
    "        del self._reply_func_list[2]\n",
    "        self.register_reply([Agent, None], MathUserProxyAgent._generate_math_reply, position=4)\n",
    "        del self._reply_func_list[3]\n",
    "        self.register_reply(\n",
    "            trigger=autogen.ConversableAgent, reply_func=MathUserProxyAgent.generate_function_call_reply, position=3\n",
    "        )\n",
    "        self.register_reply(\n",
    "            trigger=autogen.ConversableAgent, reply_func=MathUserProxyAgent._check_final_result, position=0\n",
    "        )\n",
    "\n",
    "        self.max_function_call_trial = 3\n",
    "        self.query = None\n",
    "        self.answer = None\n",
    "\n",
    "    def generate_function_call_reply(\n",
    "        self,\n",
    "        messages: Optional[List[Dict]] = None,\n",
    "        sender: Optional[autogen.ConversableAgent] = None,\n",
    "        config: Optional[Any] = None,\n",
    "    ) -> Tuple[bool, Union[Dict, None]]:\n",
    "        \"\"\"Generate a reply using function call.\"\"\"\n",
    "        if messages is None:\n",
    "            messages = self._oai_messages[sender]\n",
    "        message = messages[-1]\n",
    "        if \"function_call\" in message:\n",
    "            is_exec_success, func_return = self.execute_function(message[\"function_call\"])\n",
    "            if is_exec_success:\n",
    "                self.max_function_call_trial = 3\n",
    "                return True, func_return\n",
    "            else:\n",
    "                if self.max_function_call_trial == 0:\n",
    "                    error_message = func_return[\"content\"]\n",
    "                    self.logs[\"is_correct\"] = 0\n",
    "                    self.max_function_call_trial = 3\n",
    "                    return (\n",
    "                        True,\n",
    "                        \"The func is executed failed many times. \"\n",
    "                        + error_message\n",
    "                        + \". Please directly reply me with TERMINATE. We need to terminate the conversation.\",\n",
    "                    )\n",
    "                else:\n",
    "                    revise_prompt = \"You may make a wrong function call (It may due the arguments you provided doesn't fit the function arguments like missing required positional argument). \\\n",
    "                    If you think this error occurs due to you make a wrong function arguments input and you could make it success, please try to call this function again using the correct arguments. \\\n",
    "                    Otherwise, the error may be caused by the function itself. Please directly reply me with TERMINATE. We need to terminate the conversation. \"\n",
    "                    error_message = func_return[\"content\"]\n",
    "                    return True, \"The func is executed failed.\" + error_message + revise_prompt\n",
    "        return False, None\n",
    "\n",
    "    def initiate_chat(\n",
    "        self,\n",
    "        recipient,\n",
    "        query: None,\n",
    "        answer: None,\n",
    "        silent: Optional[bool] = False,\n",
    "        **context,\n",
    "    ):\n",
    "        self.query = query\n",
    "        if not isinstance(answer, str):\n",
    "            answer = str(answer)\n",
    "            if answer.endswith(\".0\"):\n",
    "                answer = answer[:-2]\n",
    "            self._answer = answer\n",
    "        else:\n",
    "            self._answer = answer\n",
    "        self.logs = {}\n",
    "        self._prepare_chat(recipient, True)\n",
    "\n",
    "        chat_history = []\n",
    "        error_message = None\n",
    "\n",
    "        try:\n",
    "            prompt = self.PROMPTS + context[\"problem\"]\n",
    "            self.send(prompt, recipient, silent=silent)\n",
    "        except BadRequestError as e:\n",
    "            error_message = str(e)\n",
    "            self.logs[\"is_correct\"] = 0\n",
    "            print(\"error information: {}\".format(error_message))\n",
    "\n",
    "        key = list(self.chat_messages.keys())[0]\n",
    "        chat_messages = self.chat_messages[key]\n",
    "        for item in chat_messages:\n",
    "            chat_history.append(item)\n",
    "        if error_message is not None:\n",
    "            chat_history.append(error_message)\n",
    "        recipient.reset()\n",
    "        self.reset()\n",
    "        return self.logs, chat_history\n",
    "\n",
    "    def _check_final_result(\n",
    "        self,\n",
    "        messages: Optional[List[Dict]] = None,\n",
    "        sender: Optional[autogen.Agent] = None,\n",
    "        config: Optional[Any] = None,\n",
    "    ):\n",
    "        messages = messages[-1]\n",
    "\n",
    "        if isinstance(messages, dict):\n",
    "            messages = messages.get(\"content\")\n",
    "            if messages is None:\n",
    "                return False, None\n",
    "\n",
    "        cb = extract_code(messages)\n",
    "        contain_code = False\n",
    "        for c in cb:\n",
    "            if c[0] == \"python\" or c[0] == \"wolfram\":\n",
    "                contain_code = True\n",
    "                break\n",
    "        if not contain_code and get_answer(messages) is not None and get_answer(messages) != \"\":\n",
    "            if get_answer(messages) == self._answer:\n",
    "                self.logs[\"is_correct\"] = 1\n",
    "                return True, \"The result is Correct. Please reply me with TERMINATE.\"\n",
    "            else:\n",
    "                self.logs[\"is_correct\"] = 0\n",
    "                return False, None\n",
    "        else:\n",
    "            return False, None\n",
    "\n",
    "    def _reset(self):\n",
    "        self._valid_q_count = 0\n",
    "        self._total_q_count = 0\n",
    "        self._accum_invalid_q_per_step = 0\n",
    "        self._previous_code = \"\"\n",
    "        self.last_reply = None\n",
    "\n",
    "        self.query = None\n",
    "        self.answer = None\n",
    "        self.logs = {}"
   ]
  },
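  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The termination logic of `_check_final_result` can be illustrated with a standalone sketch. This is not the notebook's implementation. `extract_code` and `get_answer` are replaced by a hypothetical regex that detects fenced `python`/`wolfram` blocks and a hypothetical `extract_boxed` helper. The sketch assumes, as in MathChat-style prompts, that the final answer is wrapped in `\\boxed{}`. The returned pair `(should_terminate, is_correct)` mirrors the method above, which ends the chat only on a correct answer and records a wrong final answer without terminating.\n",
    "\n",
    "```python\n",
    "import re\n",
    "from typing import Optional, Tuple\n",
    "\n",
    "# Hypothetical stand-in for extract_code: a fenced python or wolfram block\n",
    "# means the assistant still expects code to be executed.\n",
    "FENCE = \"`\" * 3\n",
    "CODE_BLOCK = re.compile(FENCE + r\"(python|wolfram)\\b\")\n",
    "\n",
    "\n",
    "def extract_boxed(reply: str) -> Optional[str]:\n",
    "    \"\"\"Hypothetical stand-in for get_answer: return the last \\\\boxed{...} content.\"\"\"\n",
    "    start = reply.rfind(\"\\\\boxed{\")\n",
    "    if start == -1:\n",
    "        return None\n",
    "    begin = start + len(\"\\\\boxed{\")\n",
    "    depth = 1  # track nested braces, e.g. in \\frac{1}{2}\n",
    "    for pos in range(begin, len(reply)):\n",
    "        if reply[pos] == \"{\":\n",
    "            depth += 1\n",
    "        elif reply[pos] == \"}\":\n",
    "            depth -= 1\n",
    "            if depth == 0:\n",
    "                return reply[begin:pos].strip()\n",
    "    return None  # unbalanced braces: no usable answer\n",
    "\n",
    "\n",
    "def check_final_answer(reply: str, ground_truth: str) -> Tuple[bool, Optional[int]]:\n",
    "    \"\"\"Return (should_terminate, is_correct); is_correct is None when there is no verdict.\"\"\"\n",
    "    if CODE_BLOCK.search(reply):\n",
    "        return False, None  # code still pending, so this is not a final reply\n",
    "    answer = extract_boxed(reply)\n",
    "    if not answer:\n",
    "        return False, None  # no final answer in this reply\n",
    "    if answer == ground_truth:\n",
    "        return True, 1  # correct: the chat can terminate\n",
    "    return False, 0  # wrong: record it, but let the chat continue\n",
    "```\n",
    "\n",
    "For example, `check_final_answer(\"The answer is \\\\boxed{5}.\", \"5\")` returns `(True, 1)`. A reply containing a fenced Python block returns `(False, None)` because its code has not been executed yet."
   ]
  },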
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load dataset\n",
    "\n",
    "MATAH dataset contains 12,500 challenging competition mathematics problems. Each problem in MATH has a full step-by-step solution which can be used to teach models to generate answer derivations and explanations. \n",
    "\n",
    "We strctly follow the train/test splits of [Craft](https://github.com/lifan-yuan/CRAFT). Please specific your own path to the dataset. Here we sample the first 10 algebra problems as examples. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data, train_data = [], []\n",
    "with open(\"MATH/dataset/algebra.jsonl\", \"r\", encoding=\"utf-8\") as f:\n",
    "    for line in f:\n",
    "        test_data.append(json.loads(line))\n",
    "with open(\"MATH/dataset/train/algebra.jsonl\", \"r\", encoding=\"utf-8\") as f:\n",
    "    for line in f:\n",
    "        train_data.append(json.loads(line))\n",
    "test_data, train_data = test_data[0:10], train_data[0:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Agents construction\n",
    "\n",
    "Constructing MathUserProxyAgent and AssistantAgent used in solving these problems. Here, we use gpt-4-1106-preview to construct the AssistantAgent. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_list = autogen.config_list_from_json(\n",
    "    \"OAI_CONFIG_LIST\",\n",
    ")\n",
    "mathproxyagent = MathUserProxyAgent(\n",
    "    name=\"mathproxyagent\",\n",
    "    human_input_mode=\"NEVER\",\n",
    "    code_execution_config={\"work_dir\": \"_output\", \"use_docker\": False},\n",
    "    is_termination_msg=is_termination_msg_mathchat,\n",
    ")\n",
    "assistant = autogen.AssistantAgent(\n",
    "    name=\"assistant\",\n",
    "    system_message=\"You are a helpful assistant.\",\n",
    "    llm_config={\n",
    "        \"timeout\": 600,\n",
    "        \"seed\": 42,\n",
    "        \"config_list\": config_list,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test without agent optimizations \n",
    "\n",
    "Below is the code to get the performance without the agents optimization process. \n",
    "\n",
    "In this case, the AssistantAgent and MathUserProxyAgent don't have any function calls but solely solve problems with Python."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "history_list, statistic_list = {}, {}\n",
    "for index, query in enumerate(test_data):\n",
    "    single_statistic, chat_history = mathproxyagent.initiate_chat(\n",
    "        recipient=assistant, answer=query[\"answer\"], query=[\"question\"], problem=query[\"question\"]\n",
    "    )\n",
    "    history_list[\"conversation: {index}\".format(index=index + 1)] = chat_history\n",
    "    statistic_list[\"conversation: {index}\".format(index=index + 1)] = single_statistic\n",
    "\n",
    "sum = 0\n",
    "for key, value in statistic_list.items():\n",
    "    if \"is_correct\" not in value.keys():\n",
    "        statistic_list[key][\"is_correct\"] = 0\n",
    "for key, value in statistic_list.items():\n",
    "    sum += value[\"is_correct\"]\n",
    "\n",
    "success_rate_without_agent_training = sum / len(statistic_list.keys())"
   ]
  },
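  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scoring in the cell above counts any conversation that never reached a verdict as incorrect. As a sketch, the same computation can be factored into a small helper that the training and test loops could share. The name `success_rate` is hypothetical and not part of AutoGen.\n",
    "\n",
    "```python\n",
    "from typing import Dict\n",
    "\n",
    "\n",
    "def success_rate(statistic_list: Dict[str, dict]) -> float:\n",
    "    \"\"\"Fraction of conversations marked correct; a missing verdict counts as 0.\"\"\"\n",
    "    if not statistic_list:\n",
    "        return 0.0  # avoid dividing by zero on an empty run\n",
    "    num_correct = 0\n",
    "    for stats in statistic_list.values():\n",
    "        num_correct += stats.get(\"is_correct\", 0)\n",
    "    return num_correct / len(statistic_list)\n",
    "\n",
    "\n",
    "# Example: three conversations, one of which never produced a verdict.\n",
    "example = {\n",
    "    \"conversation: 1\": {\"is_correct\": 1},\n",
    "    \"conversation: 2\": {\"is_correct\": 0},\n",
    "    \"conversation: 3\": {},\n",
    "}\n",
    "print(\"{:.1f}%\".format(success_rate(example) * 100))  # prints 33.3%\n",
    "```\n",
    "\n",
    "Unlike the loop above, the helper reads missing verdicts with `dict.get` instead of writing `0` back into `statistic_list`, so calling it has no side effects."
   ]
  },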
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Agent optimizing \n",
    "\n",
    "Then, use the AgentOptimizer to iteratively optimize the agents by optimizing the function calls according to the historical conversations and performance.    \n",
    "\n",
    "Here we optimize these two agents for ten epochs. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCH = 10\n",
    "OAI_config = {\n",
    "    \"AZURE_OPENAI_API_KEY\": \"gpt-4-1106-preview\",\n",
    "    \"api_version\": \"2023-12-01-preview\",\n",
    "    \"azure_endpoint\": \"your_azure_endpoint\",\n",
    "    \"model\": \"your_model\",\n",
    "}\n",
    "agent_optimizer = AgentOptimizer(OAI_config=OAI_config)"
   ]
  },
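  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rather than hardcoding credentials in the cell above, the same `OAI_config` dictionary can be filled from environment variables. This is a minimal sketch. The variable names `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, and `AZURE_OPENAI_MODEL` are assumptions, so use whatever names your deployment defines.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "# Hypothetical environment variable names; the dictionary keys match the cell above.\n",
    "OAI_config = {\n",
    "    \"AZURE_OPENAI_API_KEY\": os.environ.get(\"AZURE_OPENAI_API_KEY\", \"\"),\n",
    "    \"api_version\": \"2023-12-01-preview\",\n",
    "    \"azure_endpoint\": os.environ.get(\"AZURE_OPENAI_ENDPOINT\", \"\"),\n",
    "    \"model\": os.environ.get(\"AZURE_OPENAI_MODEL\", \"\"),\n",
    "}\n",
    "\n",
    "# Report missing settings up front instead of discovering them at the first API call.\n",
    "missing = [key for key, value in OAI_config.items() if not value]\n",
    "if missing:\n",
    "    print(\"Missing Azure OpenAI settings: {}\".format(\", \".join(missing)))\n",
    "```"
   ]
  },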
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "history_list, statistic_list, function_json, agent_list = [], [], [], []\n",
    "for epoch in range(EPOCH):\n",
    "    if len(history_list) != 0:\n",
    "        actions, function_json = agent_optimizer.step(history_list, statistic_list, function_json)\n",
    "        for action in actions:\n",
    "            agent_optimizer.update_function_call(action, mathproxyagent=mathproxyagent, assistant=assistant)\n",
    "    history_list, statistic_list = {}, {}\n",
    "    for index, query in enumerate(train_data):\n",
    "        single_statistic, chat_history = mathproxyagent.initiate_chat(\n",
    "            recipient=assistant, answer=query[\"answer\"], query=[\"question\"], problem=query[\"question\"]\n",
    "        )\n",
    "        history_list[\"conversation: {index}\".format(index=index + 1)] = chat_history\n",
    "        statistic_list[\"conversation: {index}\".format(index=index + 1)] = single_statistic\n",
    "\n",
    "    sum = 0\n",
    "    for key, value in statistic_list.items():\n",
    "        if \"is_correct\" not in value.keys():\n",
    "            statistic_list[key][\"is_correct\"] = 0\n",
    "    for key, value in statistic_list.items():\n",
    "        sum += value[\"is_correct\"]\n",
    "    print(\"Train_Epoch_{epoch_num}_Success_Rate: {average}%\".format(epoch_num=epoch, average=sum / len(statistic_list)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test with agent optimizations \n",
    "\n",
    "After agent optimization, the agents obtained a list of functions from the AgentOptimizers after 10 optimization iterations as shown below.\n",
    "\n",
    "We then show the final performances with/without the agent optimization process. We observe the agents after optimization are obviously better.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "history_list, statistic_list = {}, {}\n",
    "for index, query in enumerate(test_data):\n",
    "    single_statistic, chat_history = mathproxyagent.initiate_chat(\n",
    "        recipient=assistant, answer=query[\"answer\"], query=[\"question\"], problem=query[\"question\"]\n",
    "    )\n",
    "    history_list[\"conversation: {index}\".format(index=index + 1)] = chat_history\n",
    "    statistic_list[\"conversation: {index}\".format(index=index + 1)] = single_statistic\n",
    "\n",
    "sum = 0\n",
    "for key, value in statistic_list.items():\n",
    "    if \"is_correct\" not in value.keys():\n",
    "        statistic_list[key][\"is_correct\"] = 0\n",
    "for key, value in statistic_list.items():\n",
    "    sum += value[\"is_correct\"]\n",
    "\n",
    "success_rate_with_agent_training = sum / len(statistic_list.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------------------------Functions learned------------------------------------------------\n",
      "{'name': 'evaluate_expression', 'description': 'Evaluate arithmetic or mathematical expressions provided as strings.', 'arguments': {'expression': {'type': 'string', 'description': 'The mathematical expression to evaluate.'}}, 'packages': 'sympy', 'code': 'from sympy import sympify, SympifyError\\n\\ndef evaluate_expression(expression):\\n    try:\\n        result = sympify(expression)\\n        if result.is_number:\\n            result = float(result)\\n        else:\\n            result = str(result)\\n        return result\\n    except SympifyError as e:\\n        return str(e)'}\n",
      "{'name': 'calculate_compound_interest_principal', 'description': 'Calculate the principal amount needed to achieve a certain future value with quarterly compound interest.', 'arguments': {'future_value': {'type': 'number', 'description': 'The total amount of money desired in the future.'}, 'annual_interest_rate': {'type': 'number', 'description': 'The annual interest rate in decimal form.'}, 'compounding_periods': {'type': 'integer', 'description': 'The number of times interest is compounded per year.'}, 'years': {'type': 'number', 'description': 'The time in years the money is invested for.'}}, 'packages': 'sympy', 'code': \"from sympy import symbols, solve, N\\n\\nfuture_value, annual_interest_rate, compounding_periods, years = symbols('future_value annual_interest_rate compounding_periods years')\\nP = symbols('P')\\nequation = Eq(future_value, P * (1 + annual_interest_rate / compounding_periods) ** (compounding_periods * years))\\nprincipal = solve(equation, P)[0]\\nprincipal = N(principal, chop=True).evalf()\"}\n",
      "{'name': 'solve_linear_system', 'description': 'Solve a system of linear equations represented as coefficients and variables.', 'arguments': {'equations': {'type': 'array', 'items': {'type': 'array', 'items': {'type': 'number'}}}, 'variables': {'type': 'array', 'items': {'type': 'string'}}}, 'packages': 'sympy', 'code': 'from sympy import Matrix, symbols, linsolve\\n\\ndef solve_linear_system(equations, variables):\\n    if not equations or not all(len(eq) == len(variables) + 1 for eq in equations):\\n        return \"Error: Equations list is empty or not all equations have the correct length.\"\\n    if len(variables) != len(set(variables)):\\n        return \"Error: Duplicate symbols in the \\'variables\\' list.\"\\n    try:\\n        sym_vars = symbols(\\' \\'.join(variables))\\n        matrix = Matrix(equations)\\n        system = (matrix, sym_vars)\\n        solution = linsolve(system)\\n        solution_list = list(solution)\\n        if not solution_list:\\n            return \"Error: No solution exists for the given system of equations.\"\\n        result = solution_list[0] if solution_list else []\\n        return dict(zip(variables, result))\\n    except Exception as e:\\n        return f\"Error: {str(e)}\"'}\n",
      "------------------------------------------------Summary------------------------------------------------\n",
      "\n",
      "success_rate_without_agent_training: 60.0%\n",
      "\n",
      "success_rate_with_agent_training: 90.0%\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    \"------------------------------------------------Functions learned------------------------------------------------\"\n",
    ")\n",
    "for function in function_json:\n",
    "    print(function)\n",
    "print(\"------------------------------------------------Summary------------------------------------------------\\n\")\n",
    "print(\"success_rate_without_agent_training: {average}%\\n\".format(average=success_rate_without_agent_training * 100))\n",
    "print(\"success_rate_with_agent_training: {average}%\\n\".format(average=success_rate_with_agent_training * 100))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.9",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
